
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
    
    
def data_high_generator(images, labels, e, flip_rate, *, device="cuda", plot=True):
    """Build a binary-label, two-channel coloured digit dataset tagged with an environment id.

    Each image is down-sampled from 28x28 to 14x14 by keeping every second
    pixel. The label is ``digit > 2`` and is flipped with probability
    ``flip_rate``. A colour bit is that (noisy) label flipped again with
    probability ``e``. The image is duplicated into two channels and channel
    ``1 - colour`` is zeroed, so the digit survives only in channel ``colour``.

    Args:
        images: Digit images, reshapeable to ``(-1, 28, 28)``. The ``/ 255.``
            scaling suggests raw 0..255 pixel values; confirm with the caller.
            Random draws and index tensors are created on the CPU, so CPU
            tensors are assumed.
        labels: Per-image digit class. Only ``labels > 2`` is used.
        e: Colour-flip probability. 0.3 / 0.5 / 0.7 / 0.9 map to environment
            offsets 2 / 4 / 6 / 8 (compared with a small tolerance). Any
            other value maps to offset 0.
        flip_rate: Label-noise probability.
        device: Device the returned tensors are moved to. Defaults to
            ``"cuda"``, matching the previous hard-coded ``.cuda()``.
        plot: If True, show up to the first 64 samples in an 8x8 grid.

    Returns:
        dict with
          * ``'images'``: ``(N, 2, 14, 14)`` float tensor scaled by ``1/255``;
          * ``'labels'``: ``(N, 1)`` float tensor of noisy binary labels;
          * ``'env_labels'``: noisy label + environment offset, shape
            ``(N, 1, 1)``.
    """
    def torch_bernoulli(p, size):
        return (torch.rand(size) < p).float()

    def torch_xor(a, b):
        return (a - b).abs()

    n = len(labels)
    images = images.reshape((-1, 28, 28))[:, ::2, ::2]

    # Binary label, then label noise, then the colour bit. The draw order
    # (flip first, colour second) is unchanged so seeded runs behave the same.
    labels = (labels > 2).float()
    labels = torch_xor(labels, torch_bernoulli(flip_rate, n))
    colors = torch_xor(labels, torch_bernoulli(e, n))

    # Two identical channels; zero the one that does not match the colour bit.
    images = torch.stack([images, images], dim=1)
    images[torch.arange(len(images)), (1 - colors).long(), :, :] *= 0

    # Environment offset derived from e. Use a tolerance rather than ``==`` so
    # values such as 0.1 * 3 still map to their environment.
    env_offset = 0.0
    for known_e, offset in ((0.3, 2.0), (0.5, 4.0), (0.7, 6.0), (0.9, 8.0)):
        if abs(float(e) - known_e) < 1e-9:
            env_offset = offset
            break
    env_and_labels = labels[:, None] + env_offset

    if plot:
        fig = plt.figure(figsize=(15, 10))
        # Guard against datasets smaller than the 8x8 preview grid.
        for idx in range(min(64, len(images))):
            # Pad the two colour channels with an empty third channel -> RGB.
            rgb = torch.cat([images[idx].float() / 255., torch.zeros((1, 14, 14))])
            ax = fig.add_subplot(8, 8, idx + 1)
            ax.set_title("label:" + str(int(labels[idx])))
            # (C, H, W) -> (H, W, C); same pixels as the old C.T/rot90/fliplr chain.
            ax.imshow(rgb.permute(1, 2, 0).numpy())
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.5, hspace=0.7)
        # Figure-level title; plt.title would overwrite the last subplot's label.
        fig.suptitle('data1')

        # To save the visualisation, uncomment the following line.
        # plt.savefig('highimg_e={}_filp={}.png'.format(e, flip_rate))

        plt.show()

    # NOTE(review): env_and_labels is already (N, 1); the extra [:, None] makes
    # 'env_labels' (N, 1, 1), unlike 'labels'. Kept for compatibility; confirm
    # what callers expect.
    return {
        'images': (images.float() / 255.).to(device),
        'labels': labels[:, None].to(device),
        'env_labels': env_and_labels[:, None].to(device),
    }

def data_generator(images, labels, e, flip_rate, *, device="cuda", plot=True):
    """Build a three-class, two-channel coloured digit dataset.

    Each image is down-sampled from 28x28 to 14x14 by keeping every second
    pixel. Digits are bucketed into classes: ``> 4`` -> 2, ``3`` or ``4`` -> 1,
    anything else -> 0. With probability ``flip_rate`` a sample's class is
    replaced by one of the two other classes, chosen by a fair coin. The
    colour bit is ``class == 2`` flipped with probability ``e``. The image is
    duplicated into two channels and channel ``1 - colour`` is zeroed.

    The relabel/flip steps are vectorised. Seeded results are not guaranteed
    to match the old per-sample loop bit-for-bit, but the distribution is the
    same.

    Args:
        images: Digit images, reshapeable to ``(-1, 28, 28)``. The ``/ 255.``
            scaling suggests raw 0..255 pixel values; confirm with the caller.
            CPU tensors are assumed.
        labels: Per-image digit class (one value per image).
        e: Colour-flip probability.
        flip_rate: Probability of replacing a class with another class.
        device: Device the returned tensors are moved to. Defaults to
            ``"cuda"``, matching the previous hard-coded ``.cuda()``.
        plot: If True, show up to the first 64 samples in an 8x8 grid.

    Returns:
        dict with ``'images'`` ``(N, 2, 14, 14)`` float tensor scaled by
        ``1/255`` and ``'labels'`` ``(N, 1)`` float tensor of classes 0/1/2.
    """
    def torch_bernoulli(p, size):
        return (torch.rand(size) < p).float()

    def torch_xor(a, b):
        return (a - b).abs()

    n = len(labels)
    images = images.reshape((-1, 28, 28))[:, ::2, ::2]

    # Bucket digits into three classes: >4 -> 2, {3, 4} -> 1, everything else -> 0.
    digits = torch.as_tensor(labels).cpu().reshape(n)
    blabels = torch.zeros(n)
    blabels[digits > 4] = 2
    blabels[(digits == 3) | (digits == 4)] = 1

    # Label noise: a flipped sample moves to one of the two *other* classes.
    # Row c of `others` lists the classes != c, ordered so coin 0/1 picks the
    # same class the original per-sample branches did.
    flip_mask = torch_bernoulli(flip_rate, n) == 1
    others = torch.tensor([[1., 2.], [0., 2.], [0., 1.]])
    coin = torch_bernoulli(0.5, int(flip_mask.sum()))
    blabels[flip_mask] = others[blabels[flip_mask].long(), coin.long()]

    # Colour carries the (noisy) "class 2" bit, flipped again with probability e.
    colors = torch_xor((blabels > 1).float(), torch_bernoulli(e, n))
    images = torch.stack([images, images], dim=1)
    images[torch.arange(len(images)), (1 - colors).long(), :, :] *= 0

    if plot:
        fig = plt.figure(figsize=(15, 10))
        # Guard against datasets smaller than the 8x8 preview grid.
        for idx in range(min(64, len(images))):
            # Pad the two colour channels with an empty third channel -> RGB.
            rgb = torch.cat([images[idx].float() / 255., torch.zeros((1, 14, 14))])
            ax = fig.add_subplot(8, 8, idx + 1)
            ax.set_title("label:" + str(int(blabels[idx])))
            # (C, H, W) -> (H, W, C); same pixels as the old C.T/rot90/fliplr chain.
            ax.imshow(rgb.permute(1, 2, 0).numpy())
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.5, hspace=0.7)
        # Figure-level title; plt.title would overwrite the last subplot's label.
        fig.suptitle('data1')

        # To save the visualisation, uncomment the following line.
        # plt.savefig('img_e={}_filp={}.png'.format(e, flip_rate))

        plt.show()

    return {
        'images': (images.float() / 255.).to(device),
        'labels': blabels[:, None].to(device),
    }


def data_loader(envs, sample_size):
    """Generate a training/test pair of datasets, scatter-plot them, and return them.

    The training set is generated with ``envs`` and the test set with
    ``-envs``. Each gets one panel showing the first two columns of
    ``images * 50``, coloured by label 0 (blue) / 1 (red) / 2 (green). The
    figure is saved as ``"Training and Test data on e=<envs>.png"`` and shown.

    Args:
        envs: Environment value for the training set; its negation is used
            for the test set.
        sample_size: Number of samples per set; labels are reshaped to this
            length.

    Returns:
        ``[train_data, test_data]``, the dicts returned by the generator.

    NOTE(review): ``data_generator`` in this module takes
    ``(images, labels, e, flip_rate)``, but it is called here as
    ``data_generator(sample_size, env)``, which raises ``TypeError``. The
    scatter of ``w[mask][:, 0]`` / ``[:, 1]`` also looks written for a
    generator returning 2-D point features rather than images. Confirm the
    intended generator before relying on this function.
    """
    sample = []
    fig = plt.figure(figsize=(15, 13))
    # Title previously carried a copy-pasted ".png" from the savefig filename.
    fig.suptitle("Training and Test data on e={}".format(envs))

    panels = (("Training data", envs), ("Test data", -envs))
    for position, (title, env) in enumerate(panels, start=1):
        data = data_generator(sample_size, env)
        ax = fig.add_subplot(1, 2, position)
        ax.title.set_text(title)
        ax.set_ylim(-50, 50)
        ax.set_xlim(-50, 50)
        w = data['images'] * 50
        flat_labels = data['labels'].reshape(sample_size)
        for cls, color in ((0, "b"), (1, "r"), (2, "g")):
            points = w[flat_labels == cls]
            ax.scatter(points[:, 0], points[:, 1], color=color,
                       label="label:{}".format(cls), s=10)
        sample.append(data)

    # Legend is attached to the current (last-created, i.e. test) axes.
    plt.legend()
    fig.savefig("Training and Test data on e={}.png".format(envs))
    plt.show()
    return sample


def high_data_loader(env_number, envs_array, sample_size):
    """Generate one dataset per environment in ``envs_array`` and scatter-plot them.

    For environment index ``i``, the generated dict's ``'env_labels'`` is
    overwritten with ``labels + 2 * i``. Each environment gets its own panel
    in a two-column grid (at least three rows, more when there are over six
    environments). A panel shows the first two columns of ``images * 50``,
    coloured by label 0 (blue) / 1 (red). The figure is saved as
    ``"high_data env=<envs_array>.png"`` and shown.

    Args:
        env_number: Not used by this function.
        envs_array: Array of environment values; only ``.shape[0]`` and
            integer indexing are used here.
        sample_size: Number of samples per environment; labels are reshaped
            to this length.

    Returns:
        List of the generated data dicts, one per environment.

    NOTE(review): ``data_high_generator`` in this module takes
    ``(images, labels, e, flip_rate)``, but it is called here as
    ``data_high_generator(sample_size, env)``, which raises ``TypeError``.
    The scatter of ``w[mask][:, 0]`` / ``[:, 1]`` also looks written for a
    generator returning 2-D point features rather than images. Confirm the
    intended generator before relying on this function.
    """
    sample = []
    n_envs = envs_array.shape[0]
    # The original fixed 3x2 grid only fits six panels; add rows beyond that.
    n_rows = max(3, (n_envs + 1) // 2)
    fig = plt.figure(figsize=(15, 13))
    fig.suptitle("Visualization of high data under env={}".format(envs_array))
    for i in range(n_envs):
        data = data_high_generator(sample_size, envs_array[i])
        data['env_labels'] = data['labels'] + 2 * i

        ax = fig.add_subplot(n_rows, 2, i + 1)
        ax.set_ylim(-150, 150)
        ax.set_xlim(-50, 50)
        w = data['images'] * 50
        flat_labels = data['labels'].reshape(sample_size)
        for cls, color in ((0, "b"), (1, "r")):
            points = w[flat_labels == cls]
            ax.scatter(points[:, 0], points[:, 1], color=color,
                       label="label:{}".format(cls), s=10)

        sample.append(data)
    # Legend is attached to the current (last-created) axes.
    plt.legend()
    fig.savefig("high_data env={}.png".format(envs_array))
    plt.show()
    return sample


